{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# ColBERTv2: Indexing & Search Notebook"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We start by importing the relevant classes. As we'll see below, `Indexer` and `Searcher` are the two key classes here."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "sys.path.insert(0, '../')\n",
    "\n",
    "from colbert.infra import Run, RunConfig, ColBERTConfig\n",
    "from colbert.data import Queries, Collection\n",
    "from colbert import Indexer, Searcher"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The workflow here assumes an IR dataset: a set of queries and a corresponding collection of passages.\n",
    "\n",
    "The classes `Queries` and `Collection` provide a convenient interface for working with such datasets.\n",
    "\n",
    "We will use the *dev set* of the **LoTTE benchmark** introduced in the ColBERTv2 paper. The dev and test sets each contain several domain-specific corpora; we'll use the smallest dev corpus, namely `lifestyle:dev`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!mkdir -p downloads/\n",
    "\n",
    "# ColBERTv2 checkpoint trained on MS MARCO Passage Ranking (388MB compressed)\n",
    "!wget https://downloads.cs.stanford.edu/nlp/data/colbert/colbertv2/colbertv2.0.tar.gz -P downloads/\n",
    "!tar -xvzf downloads/colbertv2.0.tar.gz -C downloads/\n",
    "\n",
    "# The LoTTE dev and test sets (3.4GB compressed)\n",
    "!wget https://downloads.cs.stanford.edu/nlp/data/colbert/colbertv2/lotte.tar.gz -P downloads/\n",
    "!tar -xvzf downloads/lotte.tar.gz -C downloads/"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataroot = 'downloads/lotte'\n",
    "dataset = 'lifestyle'\n",
    "datasplit = 'dev'\n",
    "\n",
    "queries = os.path.join(dataroot, dataset, datasplit, 'questions.search.tsv')\n",
    "collection = os.path.join(dataroot, dataset, datasplit, 'collection.tsv')\n",
    "\n",
    "queries = Queries(path=queries)\n",
    "collection = Collection(path=collection)\n",
    "\n",
    "f'Loaded {len(queries)} queries and {len(collection):,} passages'"
   ]
  },
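  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For orientation, here is a minimal sketch of the kind of file that `Queries` and `Collection` read. It assumes each line of the `.tsv` files holds an integer id, a tab, and the text. The helper `read_id_text_tsv` is hypothetical and not part of the ColBERT API, and the two passages are made up."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import csv\n",
    "import io\n",
    "\n",
    "def read_id_text_tsv(handle):\n",
    "    # Parse lines of the assumed form <id><TAB><text> into a dict {id: text}.\n",
    "    # QUOTE_NONE leaves quotation marks inside passages untouched.\n",
    "    reader = csv.reader(handle, delimiter='\\t', quoting=csv.QUOTE_NONE)\n",
    "    return {int(row[0]): row[1] for row in reader if row}\n",
    "\n",
    "# Toy stand-in for a collection file (made-up passages, not LoTTE data).\n",
    "toy_tsv = io.StringIO('0\\tfirst toy passage\\n1\\tsecond toy passage\\n')\n",
    "toy = read_id_text_tsv(toy_tsv)\n",
    "print(len(toy), toy[1])"
   ]
  },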
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This loaded 417 queries and 269k passages. Let's inspect one query and one passage."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(queries[24])\n",
    "print()\n",
    "print(collection[89852])\n",
    "print()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Indexing\n",
    "\n",
    "For efficient search, we can pre-compute the ColBERT representation of each passage and index them.\n",
    "\n",
    "Below, the `Indexer` takes a model checkpoint and writes a (compressed) index to disk. We then prepare a `Searcher` for retrieval from this index.\n",
    "\n",
    "(With four Titan V GPUs, indexing should take about 13 minutes. The output is fairly long/ugly at the moment!)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "nbits = 2   # encode each dimension with 2 bits\n",
    "doc_maxlen = 300   # truncate passages at 300 tokens\n",
    "\n",
    "checkpoint = 'downloads/colbertv2.0'\n",
    "index_name = f'{dataset}.{datasplit}.{nbits}bits'"
   ]
  },
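  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make `nbits = 2` concrete: two bits can distinguish four values, so each dimension is stored as one of four bucket codes. The cell below is a simplified, pure-Python illustration that assumes quantile-based bucket boundaries. `quantize` and `dequantize` are hypothetical helpers, not the ColBERT implementation, and the real index compresses residual vectors and differs in detail."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import bisect\n",
    "import random\n",
    "\n",
    "def quantize(values, nbits):\n",
    "    # Split the observed values into 2**nbits buckets of roughly equal size.\n",
    "    levels = 2 ** nbits\n",
    "    ordered = sorted(values)\n",
    "    n = len(ordered)\n",
    "    # Bucket boundaries at the 1/levels, 2/levels, ... quantiles.\n",
    "    cutoffs = [ordered[(i * n) // levels] for i in range(1, levels)]\n",
    "    # One representative per bucket: the value in the middle of that bucket.\n",
    "    centers = [ordered[((2 * i + 1) * n) // (2 * levels)] for i in range(levels)]\n",
    "    codes = [bisect.bisect_right(cutoffs, v) for v in values]\n",
    "    return codes, centers\n",
    "\n",
    "def dequantize(codes, centers):\n",
    "    # Replace each code with its bucket's representative value.\n",
    "    return [centers[c] for c in codes]\n",
    "\n",
    "rng = random.Random(0)\n",
    "vector = [rng.gauss(0.0, 1.0) for _ in range(128)]  # toy 128-dim vector\n",
    "codes, centers = quantize(vector, nbits=2)\n",
    "approx = dequantize(codes, centers)\n",
    "print(codes[:8], sum(abs(a - b) for a, b in zip(vector, approx)) / len(vector))"
   ]
  },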
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with Run().context(RunConfig(nranks=4, experiment='notebook')):  # nranks specifies the number of GPUs to use.\n",
    "    config = ColBERTConfig(doc_maxlen=doc_maxlen, nbits=nbits)\n",
    "\n",
    "    indexer = Indexer(checkpoint=checkpoint, config=config)\n",
    "    indexer.index(name=index_name, collection=collection, overwrite=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "indexer.get_index() # You can get the absolute path of the index, if needed."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Search\n",
    "\n",
    "Having built the index, we can now prepare a `searcher` and use it to search for individual query strings.\n",
    "\n",
    "We can use the `queries` set we loaded earlier, or you can supply your own questions. Feel free to get creative! But keep in mind that this set of ~269k lifestyle passages can only answer a small, focused set of questions!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# To create the searcher using its relative name (i.e., not a full path), set\n",
    "# experiment=value_used_for_indexing in the RunConfig.\n",
    "with Run().context(RunConfig(experiment='notebook')):\n",
    "    searcher = Searcher(index=index_name)\n",
    "\n",
    "\n",
    "# If you want to customize the search latency-quality tradeoff, you can also supply a\n",
    "# config=ColBERTConfig(ncells=.., centroid_score_threshold=.., ndocs=..) argument.\n",
    "# The default settings with k <= 10 (1, 0.5, 256) give the fastest search,\n",
    "# but you can get a more exhaustive search by setting larger values of k or by\n",
    "# manually specifying more conservative ColBERTConfig settings (e.g. (4, 0.4, 4096))."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "query = queries[37]   # or supply your own query\n",
    "\n",
    "print(f\"#> {query}\")\n",
    "\n",
    "# Find the top-3 passages for this query\n",
    "results = searcher.search(query, k=3)\n",
    "\n",
    "# Print out the top-k retrieved passages\n",
    "for passage_id, passage_rank, passage_score in zip(*results):\n",
    "    print(f\"\\t [{passage_rank}] \\t\\t {passage_score:.1f} \\t\\t {searcher.collection[passage_id]}\")"
   ]
  },
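  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As the loop above shows, `searcher.search` returns three parallel sequences: passage ids, ranks, and scores. The sketch below regroups them into one record per hit, which can be easier to pass around. It uses made-up values in place of real search output, and `to_hits` is a hypothetical helper, not part of ColBERT."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def to_hits(results):\n",
    "    # results: (passage_ids, ranks, scores), as unpacked by zip(*results) above.\n",
    "    return [\n",
    "        {'passage_id': pid, 'rank': rank, 'score': score}\n",
    "        for pid, rank, score in zip(*results)\n",
    "    ]\n",
    "\n",
    "# Made-up values standing in for searcher.search(query, k=3).\n",
    "fake_results = ([101, 7, 42], [1, 2, 3], [25.3, 24.9, 23.1])\n",
    "hits = to_hits(fake_results)\n",
    "print(hits[0])"
   ]
  },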
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Batch Search\n",
    "\n",
    "In many applications, you have a large batch of queries and you need to maximize the overall throughput. For that, you can use the `searcher.search_all(queries, k)` method, which returns a `Ranking` object that organizes the results across all queries.\n",
    "\n",
    "(Batching provides many opportunities for higher-throughput search, though we have not implemented most of those optimizations for compressed indexes yet.)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "rankings = searcher.search_all(queries, k=5).todict()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "rankings[30]  # For query 30, a list of (passage_id, rank, score) for the top-k passages"
   ]
  },
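  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The dictionary form is convenient for downstream evaluation. As a sketch, the cell below computes Success@k (the fraction of queries with at least one relevant passage in the top k) from a dictionary shaped like `rankings`. The rankings and gold labels here are invented toy data, ranks are assumed to start at 1, and `success_at_k` is a hypothetical helper, not part of ColBERT."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def success_at_k(rankings, gold, k):\n",
    "    # rankings: {query_id: [(passage_id, rank, score), ...]}, as described above.\n",
    "    # gold: {query_id: set of relevant passage ids} (hypothetical labels).\n",
    "    hits = 0\n",
    "    for qid, relevant in gold.items():\n",
    "        # Assumes ranks start at 1, so rank <= k selects the top-k passages.\n",
    "        top_ids = {pid for pid, rank, _ in rankings.get(qid, []) if rank <= k}\n",
    "        hits += bool(top_ids & relevant)\n",
    "    return hits / len(gold)\n",
    "\n",
    "# Toy data: query 0 has its relevant passage at rank 2, query 1 has none retrieved.\n",
    "toy_rankings = {0: [(5, 1, 20.1), (9, 2, 19.4)], 1: [(3, 1, 18.0), (8, 2, 17.2)]}\n",
    "toy_gold = {0: {9}, 1: {4}}\n",
    "print(success_at_k(toy_rankings, toy_gold, k=2))"
   ]
  },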
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "a99ac6d2deb03d0b7ced3594556c328848678d7cea021ae1b9990e15d3ad5c49"
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
